import torch
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from swarms_torch.structs.simple_moe import SimpleMoE # Dependency: swarms-torch
from torch import Tensor, nn
from zeta.nn import Attention, OutputHead 

# Helper function
def pair(t):
    """Normalise ``t`` into a 2-tuple.

    A tuple argument is returned as-is (not copied or length-checked); any
    other value is duplicated into ``(t, t)``.
    """
    if isinstance(t, tuple):
        return t
    duplicated = (t, t)
    return duplicated

class DenseEncoderLayer(nn.Module):
    """One encoder layer: pre-norm attention and pre-norm MoE, then a final norm.

    Each sub-layer output goes through dropout and is added back to its input
    residually. A ``LayerNorm`` is applied after the MoE residual.

    Args:
        dim: Token embedding width.
        heads: Number of attention heads forwarded to ``Attention``.
        num_experts: Number of experts in the ``SimpleMoE`` feed-forward.
        dim_head: Per-head width forwarded to ``Attention``.
        dropout: Dropout probability applied to both sub-layer outputs before
            the residual add.
        ff_mult: Expert hidden-width multiplier (hidden width is
            ``dim * ff_mult``).
        top_k: Accepted for API compatibility but currently unused. It is not
            forwarded to ``SimpleMoE``.
        qk_norm_attn: Forwarded to ``Attention`` as ``qk_norm``.
        flash_attention: Request flash attention. It is only enabled when CUDA
            is available at construction time.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        num_experts: int,
        dim_head: int,
        dropout: float,
        ff_mult: int,
        top_k: int,
        qk_norm_attn: bool = True,
        flash_attention: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts
        self.dim_head = dim_head
        self.dropout_rate = dropout
        self.ff_mult = ff_mult
        self.actual_heads = heads
        # NOTE(review): `top_k` is not used anywhere below; confirm whether
        # top-k expert routing was intended for SimpleMoE.

        # Flash attention is gated on CUDA being present when the layer is
        # built; otherwise the non-flash attention path is requested.
        use_flash = flash_attention and torch.cuda.is_available()

        self.experts = SimpleMoE(
            dim=dim,
            hidden_dim=dim * self.ff_mult,
            output_dim=dim,
            num_experts=num_experts,
        )

        self.attn = Attention(
            dim=dim,
            dim_head=dim_head,
            heads=self.actual_heads,
            flash=use_flash,
            qk_norm=qk_norm_attn,
        )
        self.norm1 = nn.LayerNorm(dim)  # pre-attention norm
        self.norm2 = nn.LayerNorm(dim)  # pre-MoE norm
        self.norm3 = nn.LayerNorm(dim)  # post-MoE norm
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Apply attention and MoE sub-layers to ``x``.

        ``x`` is presumably ``(batch, seq, dim)`` (TODO confirm against the
        ``Attention`` / ``SimpleMoE`` contracts). The output has the same
        shape as ``x``.
        """
        # Pre-norm attention; `Attention` returns a 2-tuple, and only the
        # first element is used.
        attn_out, _ = self.attn(self.norm1(x))
        x = x + self.dropout_layer(attn_out)

        # Pre-norm mixture-of-experts feed-forward.
        expert_out = self.experts(self.norm2(x))
        x = x + self.dropout_layer(expert_out)

        return self.norm3(x)

class LiMoE(nn.Module):
    """Multimodal encoder: text/image embeddings fed through DenseEncoderLayers.

    Text token ids are embedded with ``nn.Embedding`` plus a learned positional
    table of length ``seq_length``. Images are cut into non-overlapping
    ``patch_size`` patches, each patch is flattened and projected to ``dim``,
    and a learned per-patch positional embedding is added. When both
    modalities are given, text tokens precede image tokens along dim 1.

    Args:
        dim: Token embedding width.
        depth: Number of ``DenseEncoderLayer`` blocks.
        heads: Attention heads per layer.
        num_tokens: Text vocabulary size.
        seq_length: Maximum supported text sequence length.
        num_experts: Experts per MoE block.
        dim_head: Per-head attention width.
        dropout: Dropout probability inside each layer.
        ff_mult: Expert hidden-width multiplier.
        patch_size: Patch side length (int or ``(h, w)`` tuple).
        image_size: Image side length (int or ``(h, w)`` tuple).
        channels: Number of image channels.
        dense_encoder_depth: Stored-only argument. It is not used by this class.
        top_k_experts: Forwarded to every layer as ``top_k``.
        stride, padding, kernel_size: Accepted for API compatibility. They are
            currently unused.
        qk_norm_attn: Forwarded to each layer's attention.
        flash_attention: Request flash attention in each layer.

    Raises:
        ValueError: If the image height/width is not divisible by the patch
            height/width.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        num_tokens: int,
        seq_length: int,
        num_experts: int,
        dim_head: int,
        dropout: float,
        ff_mult: int,
        patch_size: int,
        image_size: int,
        channels: int,
        dense_encoder_depth: int,
        top_k_experts: int,
        stride: int = 1,
        padding: int = 1,
        kernel_size: int = 3,
        qk_norm_attn: bool = True,
        flash_attention: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.heads_param = heads
        self.num_tokens = num_tokens
        self.seq_length = seq_length
        self.num_experts = num_experts
        self.dim_head = dim_head
        self.dropout_rate = dropout
        self.ff_mult = ff_mult
        self.image_size = image_size
        self.channels = channels
        self.patch_size = patch_size
        self.top_k_for_moe_config = top_k_experts
        self.qk_norm_attn = qk_norm_attn
        self.flash_attention = flash_attention

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # Raise rather than assert: asserts are stripped under `python -O`.
        if image_height % patch_height != 0 or image_width % patch_width != 0:
            raise ValueError('Image dimensions must be divisible by the patch size.')

        patch_dim = channels * patch_height * patch_width
        self.num_patches = (image_height // patch_height) * (image_width // patch_width)

        # Text: token embedding + learned positional embedding.
        self.embed = nn.Embedding(num_tokens, dim)
        self.text_pos_emb = nn.Parameter(torch.randn(1, seq_length, dim))

        # Image: flatten non-overlapping patches, then project to `dim`.
        self.img_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=patch_height,
                p2=patch_width,
            ),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        self.image_pos_emb = nn.Parameter(torch.randn(1, self.num_patches, dim))

        # Dense encoder layers.
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                DenseEncoderLayer(
                    dim=dim,
                    heads=self.heads_param,
                    num_experts=num_experts,
                    dim_head=dim_head,
                    dropout=dropout,
                    ff_mult=ff_mult,
                    top_k=self.top_k_for_moe_config,
                    qk_norm_attn=self.qk_norm_attn,
                    flash_attention=self.flash_attention,
                )
            )

        self.norm_input = nn.LayerNorm(dim)
        self.norm_output = nn.LayerNorm(dim)

    def _embed_text(self, text: Tensor) -> Tensor:
        """Embed text token ids and add the matching slice of positional embeddings.

        ``text`` is indexed along dim 1 as the sequence axis, so it is presumably
        ``(batch, seq)`` integer ids (TODO confirm).
        """
        seq_len = text.shape[1]
        max_len = self.text_pos_emb.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"Text sequence length {seq_len} exceeds the configured "
                f"seq_length {max_len}."
            )
        # Out-of-place add: avoids mutating a module's output tensor in place.
        return self.embed(text) + self.text_pos_emb[:, :seq_len]

    def _embed_image(self, image: Tensor) -> Tensor:
        """Patchify and project ``image``, then add per-patch positional embeddings.

        ``image`` must match the ``"b c (h p1) (w p2)"`` layout of the
        ``Rearrange`` pattern.
        """
        image_embeds = self.img_patch_embedding(image)
        if image_embeds.shape[1] != self.num_patches:
            raise ValueError(
                f"Image produced {image_embeds.shape[1]} patches but the model "
                f"was configured for {self.num_patches}; check image_size."
            )
        return image_embeds + self.image_pos_emb

    def forward(
        self,
        text: Tensor = None,
        image: Tensor = None,
        precomputed_embeddings: Tensor = None,
    ) -> Tensor:
        """Encode precomputed embeddings, or text and/or image inputs.

        ``precomputed_embeddings`` takes priority and is used as the token
        sequence directly. Otherwise text tokens (if any) and image patch
        tokens (if any) are concatenated along dim 1, text first.

        Raises:
            ValueError: If no input is given, if the text is longer than
                ``seq_length``, or if the image yields an unexpected number of
                patches.
        """
        if precomputed_embeddings is not None:
            tokens = precomputed_embeddings
        elif text is not None or image is not None:
            parts = []
            if text is not None:
                parts.append(self._embed_text(text))
            if image is not None:
                parts.append(self._embed_image(image))
            # At least one part is guaranteed by the `elif` condition; skip the
            # copy torch.cat would make for a single tensor.
            tokens = parts[0] if len(parts) == 1 else torch.cat(parts, dim=1)
        else:
            raise ValueError("Either precomputed_embeddings or (text or image) must be provided.")

        tokens = self.norm_input(tokens)
        for layer in self.layers:
            tokens = layer(tokens)
        return self.norm_output(tokens)
